from customlensnakeenv_single import SnakeEnv
from stable_baselines3 import PPO

# Evaluation setup: build a rendered Snake environment and load a saved PPO
# checkpoint bound to that environment.
models_dir = "models/1704738637/"
env_shape = 6  # passed to SnakeEnv as board_size; also used to size the observation reshape in the episode loop
env = SnakeEnv(render=True, board_size=env_shape)
env.reset()  # initial reset; every episode in the loop below resets again

# models_dir already ends with "/", so strip it before joining to avoid a
# doubled separator ("models/1704738637//150000000") in the checkpoint path.
model_path = f"{models_dir.rstrip('/')}/150000000"
model = PPO.load(model_path, env=env, device="cpu")

# Evaluation run: play `episodes` episodes with the loaded policy, one
# environment step per inner-loop iteration.
episodes = 200
len_arr = []  # not written to in the visible part of the loop
for ep in range(episodes):
    done = False
    obs, _ = env.reset()  # start a fresh episode; the second return value is discarded
    steps = 0
    length = 3  # NOTE(review): presumably the snake's starting length — confirm against SnakeEnv
    while not done:
        # Add a leading axis of size 1. reshape raises unless the observation
        # holds exactly env_shape * env_shape + 7 values.
        reshaped_obs = obs.reshape((1, env_shape * env_shape + 7))
        # Ask the policy for its deterministic action; the second return
        # value of predict() is discarded.
        action, _ = model.predict(reshaped_obs, deterministic=True)
        # NOTE(review): only `done` drives the loop condition; `trunc` is
        # ignored here. If SnakeEnv can report truncation without setting
        # `done`, the loop keeps stepping a finished episode — confirm and
        # consider `while not (done or trunc)`.
        obs, reward, done, trunc, info = env.step(action)